import random

import torch

from tools.utils import dim_action_space


class Memory:
    """Shared storage whose tensors all start with an empty leading dimension.

    Every tensor has shape ``(0, args.buffer_size, ...)``. ``DataStore`` later
    grows dim 0 by concatenating one ``(1, buffer_size, ...)`` slice per store.
    ``embeds`` is only allocated when ``mode != 0``. ``states`` and
    ``prev_states`` are only allocated when ``args.state`` is truthy.

    Args:
        args: configuration namespace; reads ``env_type``, ``image``,
            ``buffer_size``, ``num_latents`` and ``state``.
        obs_size: trailing shape of a single observation.
        action_space: an action space whose class is named ``Discrete``
            gives ``torch.long`` actions; anything else gives ``torch.float``.
        context: stored unchanged on the instance.
        mode: non-zero to also allocate ``embeds``.
    """

    def __init__(self, args, obs_size, action_space, context, mode=0):
        self.context = context
        self.args = args
        byte_obs = self.args.env_type == "ram" or self.args.image
        obs_dtype = torch.uint8 if byte_obs else torch.float
        is_discrete = action_space.__class__.__name__ == 'Discrete'
        act_dtype = torch.long if is_discrete else torch.float
        capacity = args.buffer_size

        def blank(*tail, dtype=None, factory=torch.zeros):
            # Zero-length leading dimension; rows are appended by DataStore.
            return factory((0, capacity, *tail), dtype=dtype, device="cpu", requires_grad=False)

        self.next_obs = blank(*obs_size, dtype=obs_dtype)
        self.obs = blank(*obs_size, dtype=obs_dtype)
        self.rewards = blank(1)
        self.masks = blank(1, factory=torch.ones)
        self.actions = blank(dim_action_space(action_space), dtype=act_dtype, factory=torch.empty)
        self.goals = blank(self.args.num_latents)
        self.goals_step = blank(1, dtype=torch.long)
        self.goals_obs = blank(*obs_size, dtype=obs_dtype)
        if mode != 0:
            self.embeds = blank(self.args.num_latents)
        if self.args.state:
            self.states = blank(2)
            self.prev_states = blank(2)


class DoubleMemoryDataStore:
    """A node that owns two DataStores built from the same constructor arguments.

    ``learnDataStore`` is created with ``mode=1``, so it writes into
    ``context.memory2``. ``correctDataStore`` is created with ``mode=0``, so
    it writes into ``context.memory``.

    ``select_indexes`` is a boolean mask over the learn store's slots, and
    ``cpt_select_indexes`` counts how many entries of the mask are set.
    ``neighbors`` holds keys into ``learnDataStore.buffer.buffers``; it is
    used by relabeling mode 4.
    """

    def __init__(self, *args, **kwargs):
        self.learnDataStore = DataStore(*args, **kwargs, mode=1, double=self)
        self.correctDataStore = DataStore(*args, **kwargs, mode=0, double=self)
        self.args = self.learnDataStore.args
        self.context = self.learnDataStore.context
        self.select_indexes = torch.zeros(self.args.buffer_size, dtype=torch.bool)
        self.cpt_select_indexes = 0

        # Only cpt_removed_index is updated inside this class (maj_select_index).
        self.cpt_removed_index = 0
        self.cpt_add_index = 0
        self.cpt_goal_removed_index = 0
        self.cpt_goal_add_index = 0
        self.original_insertions = 0
        self.id = self.correctDataStore.id
        self.neighbors = []

    def reduce_sid(self):
        """Shift both stores' slice index down by one."""
        self.learnDataStore.reduce_sid()
        self.correctDataStore.reduce_sid()

    def delete(self, *args, **kwargs):
        """Delete both stores and shift the slice index of the later, live nodes.

        Each store removes its slice from its memory. The non-deleted nodes
        with index >= this node's id then call ``reduce_sid``. NOTE(review):
        this assumes ``buffers`` is indexable by consecutive ids and that id
        order matches slice order in memory; confirm against the buffer class.
        """
        self.learnDataStore.delete(*args, **kwargs)
        self.correctDataStore.delete(*args, **kwargs)
        buffers = self.correctDataStore.buffer.buffers
        for idx in range(self.correctDataStore.id, len(buffers)):
            if not buffers[idx].is_deleted():
                buffers[idx].reduce_sid()
        self.cpt_select_indexes = 0
        self.select_indexes = None
        self.goal_select_indexes = None

    def insert(self, *args, mode="goal", goal_step=None, **kwargs):
        """Insert a transition and return the slot index it was written to.

        With ``mode == "goal"`` the transition goes to the correct store.
        Otherwise it goes to the learn store, and the written slot is marked
        as selected.
        """
        if mode == "goal":
            return self.correctDataStore.insert(*args, goal_step=goal_step, **kwargs)
        ind = self.learnDataStore.insert(*args, goal_step=goal_step, **kwargs)
        self.original_insertions += 1
        self.maj_select_index(ind, True)
        return ind

    def maj_select_index(self, ind, win):
        """Set (``win``) or clear the selection flag of slot ``ind``, keeping the count exact."""
        if self.select_indexes[ind]:
            self.cpt_select_indexes -= 1

        if win:
            self.select_indexes[ind:ind + 1] = True
            self.cpt_select_indexes += 1
        else:
            self.select_indexes[ind:ind + 1] = False
            self.cpt_removed_index += 1

    def get_embeds_samples(self):
        """Return the selected learn-store rows of ``states`` if ``args.state`` is set, else of ``embeds``."""
        sid = self.learnDataStore.sid
        if self.args.state:
            return self.context.memory2.states[sid, self.select_indexes]
        return self.context.memory2.embeds[sid, self.select_indexes]

    def get_obs_to_embed(self, selection, state=False):
        """Return ``next_obs`` of the selected learn-store slots, indexed by ``selection``.

        When ``state`` is true, the matching state row is also cached in
        ``self.tmp_state`` with shape ``(1, -1)``.
        """
        sid = self.learnDataStore.sid
        if state:
            self.tmp_state = self.context.memory2.states[sid, self.select_indexes][selection].view(1, -1)
        return self.context.memory2.next_obs[sid, self.select_indexes][selection]

    def is_deleted(self):
        return self.learnDataStore.is_deleted()

    @property
    def insertions(self):
        """Learn-store insertion count, or -10000 once the node is deleted."""
        return self.learnDataStore.insertions if not self.is_deleted() else -10000

    @property
    def total_size(self):
        return self.learnDataStore.total_size + self.correctDataStore.total_size

    @property
    def max_goalstep(self):
        return max(self.learnDataStore.max_goalstep, self.correctDataStore.max_goalstep)

    @property
    def select_size(self):
        return self.cpt_select_indexes

    def available(self):
        return self.total_size > 0 and not self.is_deleted()

    def sample(self, batch, i, **kwargs):
        """Draw one transition and, depending on ``args.relabeling``, fill ``batch.label_obs[i]``.

        The transition is taken from the learn store (``bid == 1``) or from
        the correct store (``bid == 0``). When a relabel mode applies,
        ``batch.label_obs[i]`` receives an observation sampled from the first
        available candidate store. If no candidate is available, it receives
        the transition's own ``next_obs``.

        Returns:
            ``(sid, ind, self.id, bid)``.
        """
        if not self.correctDataStore.available() or self.args.type != 0:
            store, bid = self.learnDataStore, 1
        elif not self.learnDataStore.available():
            store, bid = self.correctDataStore, 0
        elif random.randint(0, 1) == 0:
            store, bid = self.learnDataStore, 1
        else:
            store, bid = self.correctDataStore, 0
        sid, ind = store.sample()

        # Candidates are produced lazily, so no extra random draws happen
        # after an available store has been found.
        relabeling = self.args.relabeling
        if relabeling == 4 and bid == 1:
            candidates = (self._random_walk() for _ in range(5))
        elif (relabeling == 1 and bid == 1) or (self.args.relabeling2 and bid == 0):
            candidates = self._shuffled_neighborhood()
        elif relabeling == 2 and bid == 1:
            candidates = (self._cluster_store() for _ in range(4))
        else:
            return sid, ind, self.id, bid

        for neighbor in candidates:
            if neighbor is not None and neighbor.available():
                batch.label_obs[i] = neighbor.sample_obs()
                break
        else:
            batch.label_obs[i] = self._own_next_obs(sid, ind, bid)
        return sid, ind, self.id, bid

    def _own_next_obs(self, sid, ind, bid):
        """Return the sampled transition's own ``next_obs`` (used when no relabel store is available)."""
        memory = self.context.memory2 if bid == 1 else self.context.memory
        return memory.next_obs[sid, ind]

    def _random_walk(self):
        """Take ``args.plan_steps`` random hops along ``neighbors``, starting from this node.

        Every call restarts from this node's own ``neighbors``, so the node
        reached is within ``plan_steps`` hops of this one.

        Returns:
            The learn store of the node reached. Returns ``None`` when
            ``plan_steps`` is 0 or when the walk reaches a node without
            neighbours; ``random.randint(0, -1)`` would raise in that case.
        """
        buffers = self.learnDataStore.buffer.buffers
        keys = self.neighbors
        neighbor = None
        for _ in range(self.args.plan_steps):
            if not keys:
                return None
            node = buffers[keys[random.randint(0, len(keys) - 1)]]
            neighbor = node.learnDataStore
            keys = node.neighbors
        return neighbor

    def _shuffled_neighborhood(self):
        """Return a lazy iterator over the learn stores of this node and its buffer neighbours, in random order."""
        own_id = self.learnDataStore.id
        keys = [own_id] + list(self.learnDataStore.buffer.neighbors(own_id))
        random.shuffle(keys)
        buffers = self.learnDataStore.buffer.buffers
        return (buffers[key].learnDataStore for key in keys)

    def _cluster_store(self):
        """Return the learn store of the node key drawn by ``sample_cluster``."""
        return self.learnDataStore.buffer.buffers[self.sample_cluster()].learnDataStore

    def sample_cluster(self):
        """Draw a node key.

        The key comes from ``coord_actor.dist_weights`` when that is set;
        otherwise it is a uniformly random buffer key.
        """
        if self.learnDataStore.buffer.context.coord_actor.dist_weights is not None:
            key = self.learnDataStore.buffer.context.coord_actor.dist_weights.sample().item()
            return key
        return self.learnDataStore.buffer.buffers.random_key()


class DataStore:
    """Fixed-capacity transition store for one node, backed by one slice of a shared memory.

    On construction, a ``(1, buffer_size, ...)`` slice is appended to every
    tensor of the target memory, and ``sid`` records that slice's index along
    dim 0. The target memory is ``context.memory`` when ``mode == 0`` and
    ``context.memory2`` otherwise.

    NOTE(review): ``min_goalstep`` / ``max_goalstep`` are only assigned in
    ``load``. ``can_learn`` and ``save`` raise AttributeError unless outside
    code sets them first; confirm where they are initialised.
    """

    def __init__(self, id, args, obs_size, action_space, context, buffer, mode=0, double=None):
        self.target_memory = context.memory if mode == 0 else context.memory2
        self.buffer = buffer
        self.context = context
        self.id = id
        self.args = args
        self.ch_goal = None
        self.mode = mode
        dtype_obs = torch.uint8 if self.args.env_type == "ram" or self.args.image else torch.float
        dtype_act = torch.long if action_space.__class__.__name__ == 'Discrete' else torch.float
        device = "cpu"
        size = args.buffer_size
        mem = self.target_memory

        def new_slot(*trailing, dtype=None, fill=torch.zeros):
            # A (1, buffer_size, *trailing) slice owned by this store.
            return fill((1, size, *trailing), dtype=dtype, device=device, requires_grad=False)

        def grow(name, slot):
            setattr(mem, name, torch.cat((getattr(mem, name), slot), dim=0))

        grow("next_obs", new_slot(*obs_size, dtype=dtype_obs))
        grow("obs", new_slot(*obs_size, dtype=dtype_obs))
        grow("rewards", new_slot(1))
        grow("masks", new_slot(1, fill=torch.ones))
        grow("actions", new_slot(dim_action_space(action_space), dtype=dtype_act, fill=torch.empty))
        grow("goals", new_slot(self.args.num_latents))
        grow("goals_step", new_slot(1, dtype=torch.long))
        grow("goals_obs", new_slot(*obs_size, dtype=dtype_obs))
        if self.mode != 0:
            grow("embeds", new_slot(self.args.num_latents))
        if self.args.state:
            # Default float dtype, the same as Memory.states. A uint8 slice
            # would truncate state values, and older torch versions reject
            # mixed dtypes in cat.
            grow("states", new_slot(2))
            grow("prev_states", new_slot(2))

        self.step = 0
        self.total_size = 0
        self.insertions = 0
        self.deleted = False
        self.sid = self.target_memory.obs.shape[0] - 1
        self.double = double

    def reduce_sid(self):
        """Shift this store's slice index down after an earlier slice was removed."""
        self.sid -= 1

    def available(self):
        return self.total_size > 0 and not self.is_deleted()

    def insert(self, obs, next_obs, actions, masks, reward, *args, goal=None, goal_step=None, max_substeps=None,
               goal_obs=None, embed=None, state=None, prev_state=None, **kwargs):
        """Write one transition and return the slot index used.

        While the store is not full, the slot is ``self.step``. Once
        ``total_size`` reaches ``buffer_size``, a uniformly random existing
        slot is overwritten. The rule is the same for every mode.

        ``goal`` and ``goal_step`` are required in practice: ``goal.squeeze()``
        and ``goal_step + 1`` are evaluated unconditionally. ``embed`` is used
        when ``mode != 0``; ``state``/``prev_state`` are used when
        ``args.state`` is set.
        """
        self.insertions += 1

        if self.total_size >= self.args.buffer_size:
            stepit = random.randint(0, self.total_size - 1)
        else:
            stepit = self.step

        next_step = (self.step + 1) % self.args.buffer_size
        mem = self.target_memory
        mem.obs[self.sid, stepit].copy_(obs)
        mem.actions[self.sid, stepit].copy_(actions)
        mem.rewards[self.sid, stepit].copy_(reward)
        mem.masks[self.sid, stepit].copy_(masks)
        mem.goals[self.sid, stepit].copy_(goal.squeeze())
        mem.goals_step[self.sid, stepit:stepit + 1] = goal_step + 1

        if self.mode != 0:
            mem.embeds[self.sid, stepit] = embed

        if self.args.state:
            mem.states[self.sid, stepit] = state
            mem.prev_states[self.sid, stepit] = prev_state

        mem.goals_obs[self.sid, stepit].copy_(goal_obs)
        mem.next_obs[self.sid, stepit].copy_(next_obs)

        self.total_size = max(self.total_size, stepit + 1)
        self.step = next_step
        return stepit

    def generate(self, *args, **kwargs):
        """Pick a random filled slot and return ``(goals_step value, slot index)``.

        Raises ValueError (from ``random.randint``) when the store is empty.
        """
        ind = random.randint(0, self.total_size - 1)
        return self.target_memory.goals_step[self.sid, ind].item(), ind

    def can_learn(self, goal_step):
        return goal_step >= self.min_goalstep

    def sample_obs(self, state=False):
        """Return the ``next_obs`` of a random slot, keeping a leading dim of size 1.

        When ``state`` is true, the matching state is cached in ``self.tmp_state``.
        """
        _, ind = self.generate()
        if state:
            self.tmp_state = self.target_memory.states[self.sid, ind:ind + 1, :].view(1, -1)
        return self.target_memory.next_obs[self.sid, ind:ind + 1, :]

    def get_obs(self, gs, ind):
        return self.target_memory.next_obs[self.sid, ind:ind + 1, :]

    def sample(self):
        """Return ``(sid, ind)`` for a random filled slot."""
        _, ind = self.generate()
        return self.sid, ind

    def is_deleted(self):
        return self.deleted

    def delete(self, node_goal, *args, **kwargs):
        """Mark the store deleted and remove its slice from every target-memory tensor.

        The ``sid`` of other stores is not adjusted here; callers must call
        ``reduce_sid`` on the stores whose slice came after this one.
        """
        self.deleted = True
        mem = self.target_memory
        names = ["obs", "next_obs", "actions", "rewards", "masks", "goals", "goals_step", "goals_obs"]
        if self.mode != 0:
            names.append("embeds")
        if self.args.state:
            names += ["states", "prev_states"]
        for name in names:
            tensor = getattr(mem, name)
            setattr(mem, name, torch.cat((tensor[:self.sid], tensor[self.sid + 1:]), dim=0))

    def load(self, datastore, node=None):
        """Restore the bookkeeping fields produced by ``save``."""
        self.min_goalstep = datastore['min_goalstep']
        self.max_goalstep = datastore['max_goalstep']
        self.insertions = datastore['insertions']
        self.deleted = datastore['deleted']

    def save(self):
        """Return the bookkeeping fields as a dict (tensor contents are not included)."""
        datastore = {}
        datastore['insertions'] = self.insertions
        datastore['max_goalstep'] = self.max_goalstep
        datastore['min_goalstep'] = self.min_goalstep
        datastore['deleted'] = self.deleted
        return datastore